# -*- coding:utf-8 -*-

import tensorflow as tf

"""
优化器相关
"""


def optimizers(optimizer_name, learning_rate, decay=0.9, momentum=0.9):
    """
    Select and construct a TF optimizer by name.

    :param optimizer_name: one of 'momentum', 'rmsprop', 'adam', 'sgd'
    :param learning_rate: learning rate handed to the optimizer constructor
    :param decay: decay rate (only consumed by 'rmsprop')
    :param momentum: momentum value (consumed by 'momentum' and 'rmsprop')
    :return: a tf.train.Optimizer instance
    :raises ValueError: if optimizer_name is not one of the supported types
    """
    if optimizer_name == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=momentum)
    elif optimizer_name == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(learning_rate, decay=decay, momentum=momentum)
    elif optimizer_name == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate)
    elif optimizer_name == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        # Include the offending name so the misconfiguration is easy to trace.
        raise ValueError('不支持该类型优化器: {}'.format(optimizer_name))
    return optimizer


def get_optimizer(config, learning_rate, loss, global_step, clip_norm=5):
    """
    Build the training op for the given loss.

    :param config: config object; reads config.optimizer.name and
        config.optimizer.use_clip
    :param learning_rate: learning rate (scalar or tensor)
    :param loss: scalar loss tensor to minimize
    :param global_step: step-counter variable, incremented once per update
    :param clip_norm: global-norm threshold applied when gradient clipping is
        enabled (default 5, matching the previously hard-coded value)
    :return: the training op
    """
    optimizer = optimizers(config.optimizer.name, learning_rate)

    if config.optimizer.use_clip:
        # Clip by global norm to guard against exploding gradients.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), clip_norm)
        train_step = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)
    else:
        # No clipping: let the optimizer compute and apply gradients itself.
        train_step = optimizer.minimize(loss, global_step=global_step)
    return train_step

